from __future__ import annotations

import importlib
import typing as t

from sqlglot import expressions as exp


# Keys used in the payload dicts that dump() produces and load() consumes.
INDEX: t.Final = "i"  # list position of the parent node's payload
ARG_KEY: t.Final = "k"  # name of the parent argument this node is stored under
IS_ARRAY: t.Final = "a"  # present (True) when that argument holds a list
CLASS: t.Final = "c"  # class name of an expression node (or the DATA_TYPE sentinel)
TYPE: t.Final = "t"  # the node's type, itself serialized with dump()
COMMENTS: t.Final = "o"  # the node's comments
META: t.Final = "m"  # the node's _meta mapping
VALUE: t.Final = "v"  # a plain (non-expression) value, or a DataType.Type's value

# CLASS value marking a payload that encodes an exp.DataType.Type enum member.
DATA_TYPE: t.Final = "DataType.Type"


def dump(expression: exp.Expression) -> t.List[t.Dict[str, t.Any]]:
    """
    Flatten an Expression tree into a JSON serializable list of payload dicts.

    Nodes are emitted in pre-order: a node comes before its children, and the
    children follow their argument order (list items in list order). Every
    payload except the first records, under INDEX, the list position of its
    parent's payload, under ARG_KEY the argument name it was stored under, and
    under IS_ARRAY whether that argument is a list. Expression nodes carry
    CLASS plus optional TYPE, COMMENTS and META. DataType.Type members are
    tagged with CLASS=DATA_TYPE and their value. Any other object is stored
    as-is under VALUE.
    """
    payloads: t.List[t.Dict[str, t.Any]] = []
    # Explicit LIFO work list, so tree depth does not grow the Python call
    # stack. Each entry is (node, parent_position, arg_key, in_list).
    pending: t.List[t.Tuple[t.Any, t.Optional[int], t.Optional[str], bool]] = [
        (expression, None, None, False)
    ]

    while pending:
        node, parent_position, key, in_list = pending.pop()
        # The payload appended below lands at this position. Children refer
        # back to it through INDEX.
        position = len(payloads)

        entry: t.Dict[str, t.Any] = {}
        if parent_position is not None:
            entry[INDEX] = parent_position
        if key is not None:
            entry[ARG_KEY] = key
        if in_list:
            entry[IS_ARRAY] = in_list
        payloads.append(entry)

        if not hasattr(node, "parent"):
            # Leaf value: either a DataType.Type enum member or a plain value.
            if type(node) is exp.DataType.Type:
                entry[CLASS] = DATA_TYPE
                entry[VALUE] = node.value
            else:
                entry[VALUE] = node
            continue

        node_class = node.__class__
        qualified_name = node_class.__qualname__
        # Classes living outside sqlglot.expressions are prefixed with their
        # module so that loading can import them again.
        if node_class.__module__ != exp.__name__:
            qualified_name = f"{node.__module__}.{qualified_name}"
        entry[CLASS] = qualified_name

        if node.type:
            entry[TYPE] = dump(node.type)
        if node.comments:
            entry[COMMENTS] = node.comments
        if node._meta is not None:
            entry[META] = node._meta

        if node.args:
            children = []
            for arg_name, arg_value in node.args.items():
                if type(arg_value) is list:
                    children.extend(
                        (item, position, arg_name, True) for item in arg_value
                    )
                elif arg_value is not None:
                    children.append((arg_value, position, arg_name, False))
            # Push in reverse so the first child is popped (and emitted) first.
            pending.extend(reversed(children))

    return payloads


@t.overload
def load(payloads: None) -> None: ...


@t.overload
def load(payloads: t.List[t.Dict[str, t.Any]]) -> exp.Expression: ...


def load(payloads):
    """
    Rebuild an Expression tree from a list of payload dicts produced by dump.

    The first payload becomes the root. Each later payload is turned into a
    node and attached to the node built from payload[INDEX]. If IS_ARRAY is
    set, the node is appended to the list argument payload[ARG_KEY].
    Otherwise it is stored as that argument directly. A falsy ``payloads``
    (None or an empty list) yields None.
    """
    if not payloads:
        return None

    built: t.List[t.Any] = []
    for position, payload in enumerate(payloads):
        node = _load(payload)
        built.append(node)

        if position == 0:
            # The root has no parent to attach to.
            continue

        owner = built[payload[INDEX]]
        key = payload[ARG_KEY]
        if payload.get(IS_ARRAY):
            owner.append(key, node)
        else:
            owner.set(key, node)

    return built[0]


def _load(payload: t.Dict[str, t.Any]) -> exp.Expression | exp.DataType.Type:
    """
    Build one still-unattached node from a single payload produced by dump.

    A payload without CLASS holds a plain value, which is returned as-is. A
    payload tagged DATA_TYPE is rebuilt as an exp.DataType.Type member. Any
    other CLASS names an expression class in one of two forms:
    - a bare name, looked up in sqlglot.expressions;
    - a dotted "module.Class" path, whose module is imported.

    The class is instantiated without arguments, and its type, comments and
    meta are restored. Children are attached afterwards by the caller.

    Raises:
        ValueError: if CLASS resolves to something that is not an
            exp.Expression subclass.
    """
    class_name = payload.get(CLASS)

    if not class_name:
        return payload[VALUE]
    if class_name == DATA_TYPE:
        return exp.DataType.Type(payload[VALUE])

    if "." in class_name:
        module_path, _, attr_name = class_name.rpartition(".")
        # NOTE: this still imports the named module (running its import-time
        # code). The subclass check below only prevents calling arbitrary
        # objects from it.
        module = importlib.import_module(module_path)
    else:
        module, attr_name = exp, class_name

    expression_class = getattr(module, attr_name)
    # Payloads are data. Without this check, a crafted payload such as
    # {"c": "os.abort"} would make us call that function. dump() only writes
    # class names for expression nodes, so valid payloads pass.
    if not (
        isinstance(expression_class, type)
        and issubclass(expression_class, exp.Expression)
    ):
        raise ValueError(
            f"Cannot load {class_name!r}: not a sqlglot Expression subclass"
        )

    expression = expression_class()
    expression.type = load(payload.get(TYPE))
    expression.comments = payload.get(COMMENTS)
    expression._meta = payload.get(META)
    return expression
